Source code for hysop.backend.device.codegen.base.utils

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import hashlib


class WriteOnceDict(dict):
    """dict that refuses to overwrite an existing key while locked.

    Instances start locked: assigning ``d[key] = val`` for a key that is
    already present raises ``RuntimeError``. ``release()`` allows overwrites
    and ``lock()`` re-enables the protection.

    Only item assignment goes through the check. Entries given to the
    constructor, and plain ``dict`` methods such as ``update``/``setdefault``
    (unless a subclass overrides them), store values directly.
    """

    def __init__(self, *args, **kargs):
        # Accept positional arguments exactly like dict(): subclasses such as
        # ReadDefaultWriteOnceDict and ArgDict forward ``*args`` to this
        # constructor, which must therefore not be keyword-only.
        super().__init__(*args, **kargs)
        self.lock()

    def __setitem__(self, key, val):
        """Store ``val`` under ``key``.

        Raises
        ------
        RuntimeError
            If the dict is locked and ``key`` is already present.
        """
        if not self.allow_overwrites and key in self:
            raise RuntimeError(
                f"Key {key} already in use for variable {str(self[key])}!"
            )
        super().__setitem__(key, val)

    def lock(self):
        """Forbid overwriting existing keys (the initial state)."""
        self.allow_overwrites = False

    def release(self):
        """Allow existing keys to be overwritten until ``lock()`` is called."""
        self.allow_overwrites = True

    def translate(self, key2key_dict):
        """Return a new locked WriteOnceDict re-keyed through ``key2key_dict``.

        For every ``new_key`` in ``key2key_dict`` the result holds
        ``out[new_key] = self[key2key_dict[new_key]]``. A missing source key
        raises ``KeyError`` unless ``self.__getitem__`` supplies a default.
        """
        out = WriteOnceDict()
        for new_key in key2key_dict.keys():
            out[new_key] = self[key2key_dict[new_key]]
        return out
class ReadDefaultWriteOnceDict(WriteOnceDict):
    """Write-once dict whose ``d[key]`` falls back to a default value.

    Reading a missing key returns ``default_val`` without inserting it.
    Writes keep the write-once semantics of the parent class.
    """

    def __init__(self, default_val, *args, **kargs):
        super().__init__(*args, **kargs)
        # Value handed back by ``self[key]`` whenever ``key`` is absent.
        self.default_val = default_val

    def __getitem__(self, key):
        if key in self:
            return super().__getitem__(key)
        return self.default_val
class VarDict(WriteOnceDict):
    """Write-once dict restricted to ``str`` keys and ``CodegenVariable`` values."""

    def __setitem__(self, key, val):
        # Imported at call time, presumably to avoid a circular import with
        # the variables module (NOTE(review): confirm).
        from hysop.backend.device.codegen.base.variables import CodegenVariable

        if not isinstance(key, str):
            raise TypeError("VarDict key should be a string!")
        if not isinstance(val, CodegenVariable):
            raise TypeError("VarDict value should inherit CodegenVariable!")
        super().__setitem__(key, val)
class ArgDict(WriteOnceDict):
    """Ordered write-once mapping of argument names to codegen variables.

    Keys must be ``str`` and values must inherit ``CodegenVariable``. The
    order in which argument names are first inserted is recorded in
    ``arg_order``. That order drives ``items()``, ``build_args()`` and the
    function-name suffix computations.

    Parameters
    ----------
    overloading_allowed : bool
        When False, ``function_name_suffix`` delegates to
        ``codegen_name_suffix``, which additionally encodes every argument's
        ctype (and the names of unknown symbolic arguments) in the suffix.
    *args, **kargs
        Initial entries, accepted like ``dict(*args, **kargs)``. They are
        inserted through ``__setitem__``, so they are validated and
        recorded in ``arg_order``.
    """

    def __init__(self, overloading_allowed=False, *args, **kargs):
        super().__init__()
        self.arg_order = []
        self.overloading_allowed = overloading_allowed
        # dict.__init__ would store initial entries unvalidated and without
        # recording them in arg_order, so items()/build_args() would silently
        # ignore them; route them through __setitem__ instead.
        self.update(dict(*args, **kargs))

    def __setitem__(self, key, val):
        """Validate and store an argument, appending new names to ``arg_order``.

        Raises
        ------
        TypeError
            If ``key`` is not a ``str`` or ``val`` is not a ``CodegenVariable``.
        RuntimeError
            If the dict is locked and ``key`` is already present.
        """
        from hysop.backend.device.codegen.base.variables import CodegenVariable

        if not isinstance(key, str):
            raise TypeError("ArgDict key should be a string!")
        if not isinstance(val, CodegenVariable):
            raise TypeError("ArgDict value should inherit CodegenVariable!")
        # Only a brand-new name extends the order; an allowed overwrite
        # (after release()) keeps the argument at its original position.
        is_new = key not in self
        super().__setitem__(key, val)
        if is_new:
            self.arg_order.append(key)

    def items(self):
        """Return an iterator over ``(name, variable)`` pairs in argument order."""
        return iter([(argname, self[argname]) for argname in self.arg_order])

    def update(self, other):
        """Insert every ``(key, val)`` of ``other`` via ``__setitem__``; return self."""
        for key, val in other.items():
            self[key] = val
        return self

    def build_args(self):
        """Split the arguments into prototype/implementation strings and constants.

        Returns
        -------
        tuple
            ``(function_proto_args, function_impl_args, constant_args)``:
            ``var.argument(impl=False)`` / ``var.argument(impl=True)`` for
            every symbolic argument that is not a known constant (in argument
            order), and the variables in symbolic mode whose value is known.
            One trailing ``"\\n"`` is removed from the last prototype and the
            last implementation argument.
        """
        function_proto_args = []
        function_impl_args = []
        constant_args = []
        for varname in self.arg_order:
            var = self[varname]
            if var.symbolic_mode and var.known():
                constant_args.append(var)
            elif var.is_symbolic():
                function_proto_args.append(var.argument(impl=False))
                function_impl_args.append(var.argument(impl=True))
            else:
                assert var.known()
                assert not var.symbolic_mode
                assert not var.is_symbolic()

        def _chomp(arg):
            # Drop a single trailing newline so the argument list does not
            # end with a dangling line break.
            if arg and arg[-1] == "\n":
                return arg[:-1]
            return arg

        # Each list's last entry is checked on its own so an empty string in
        # either list cannot trigger an IndexError.
        if function_proto_args:
            function_proto_args[-1] = _chomp(function_proto_args[-1])
        if function_impl_args:
            function_impl_args[-1] = _chomp(function_impl_args[-1])
        return function_proto_args, function_impl_args, constant_args

    def function_name_suffix(self, return_type, known_args):
        """Return a short ``"_<hash>"`` suffix identifying this specialization.

        Without overloading, delegates to ``codegen_name_suffix``. With
        overloading, only the return type and the values of non-symbolic
        arguments (or of symbolic ones given in ``known_args``) are hashed.
        """
        if not self.overloading_allowed:
            return self.codegen_name_suffix(return_type, known_args)
        parts = [f"({return_type})_"]
        for varname in self.arg_order:
            var = self[varname]
            if not var.is_symbolic():
                parts.append(f"_{var.name}={var.sval()}")
            elif known_args and (varname in known_args):
                tmp = var.copy()
                tmp.set_value(known_args[varname])
                parts.append(f"_{var.name}={tmp.sval()}")
        # The hashed string always starts with the return type, so it is
        # never empty and a suffix is always produced.
        return "_" + self.hash("".join(parts))

    # handle type function overloading
    def codegen_name_suffix(self, return_type, known_args):
        """Return a ``"_<hash>"`` suffix that also encodes argument ctypes.

        Every argument contributes ``(ctype)name``, followed by ``=value``
        when the value is known (non-symbolic, or provided in ``known_args``).
        """
        parts = [f"({return_type})_"]
        for varname in self.arg_order:
            var = self[varname]
            if not var.is_symbolic():
                parts.append(f"_({var.ctype}){var.name}={var.sval()}")
            elif known_args and (varname in known_args):
                tmp = var.copy()
                tmp.set_value(known_args[varname])
                parts.append(f"_({var.ctype}){var.name}={tmp.sval()}")
            else:
                parts.append(f"_({var.ctype}){var.name}")
        return "_" + self.hash("".join(parts))

    # Four hex digits = 16 bits: by the birthday bound this stays robust with
    # up to about sqrt(16**nb) = 2**(2*nb) = 256 functions sharing a basename.
    def hash(self, string):
        """Return the first 4 hex digits of the SHA-1 of ``string`` (UTF-8)."""
        return hashlib.sha1(string.encode("utf-8")).hexdigest()[:4]
class SortedDict(dict):
    """dict whose keys, values and items are reported in sorted key order.

    Keys are ordered by their ``name`` attribute when they have one, and by
    ``str(key)`` otherwise. ``keys()``/``values()`` return lists.
    ``items()``, ``iterkeys()``, ``itervalues()`` and ``iteritems()`` return
    iterators. Iterating the dict itself yields the keys in sorted order.
    """

    @classmethod
    def _key(cls, k):
        """Sort key: ``k.name`` if ``k`` has a ``name`` attribute, else ``str(k)``."""
        return k.name if hasattr(k, "name") else str(k)

    def keys(self):
        """Return the keys as a sorted list."""
        return sorted(super().keys(), key=self._key)

    def iterkeys(self):
        """Return an iterator over the sorted keys."""
        return iter(sorted(super().keys(), key=self._key))

    def values(self):
        """Return the values as a list, in sorted key order."""
        return [self[k] for k in self.keys()]

    def itervalues(self):
        """Return an iterator over the values, in sorted key order."""
        return iter(self[k] for k in self.keys())

    def items(self):
        """Return an iterator over ``(key, value)`` pairs in sorted key order.

        Unlike ``keys()``/``values()``, this returns an iterator, not a list.
        """
        return iter((k, self[k]) for k in self.keys())

    def iteritems(self):
        """Return an iterator over ``(key, value)`` pairs in sorted key order."""
        return self.items()

    def __iter__(self):
        return self.iterkeys()